{
 "cells": [
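  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Statistical analysis\n",
    "\n",
    "Given the columns to analyze, this step computes a p_value for each feature; all p_values are computed with a t-test. Statistics can be broken down by `group`, for example train, val, and test.\n",
    "\n",
    "The two kinds of features are summarized differently:\n",
    "\n",
    "1. Discrete features: counts and proportions are reported.\n",
    "2. Continuous features: the mean and variance are reported."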
   ]
  },
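  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two kinds of summaries concrete, here is a minimal sketch on a toy table. It calls `scipy.stats.ttest_ind` and `pandas.crosstab` directly and is not the `clinic_stats` implementation; the columns `age` and `sex` and all values are hypothetical.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "from scipy import stats\n",
    "\n",
    "# Hypothetical toy data; column names and values are illustrative only.\n",
    "df = pd.DataFrame({\n",
    "    'age': [61, 58, 70, 45, 52, 66, 49, 73],\n",
    "    'sex': [0, 1, 1, 0, 1, 0, 0, 1],\n",
    "    'label': [1, 1, 1, 0, 0, 1, 0, 1],\n",
    "})\n",
    "\n",
    "# Continuous feature: mean and variance per label, plus a two-sample t-test.\n",
    "age_summary = df.groupby('label')['age'].agg(['mean', 'var'])\n",
    "t_stat, p_val = stats.ttest_ind(df.loc[df['label'] == 0, 'age'],\n",
    "                                df.loc[df['label'] == 1, 'age'])\n",
    "\n",
    "# Discrete feature: counts and within-label proportions.\n",
    "sex_counts = pd.crosstab(df['sex'], df['label'])\n",
    "sex_ratio = pd.crosstab(df['sex'], df['label'], normalize='columns')\n",
    "```\n",
    "\n",
    "Here `age_summary` has one row per label value, and each column of `sex_ratio` sums to 1."
   ]
  },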
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from onekey_algo import OnekeyDS as okds\n",
    "from onekey_algo import get_param_in_cwd\n",
    "from onekey_algo.custom.utils import print_join_info\n",
    "\n",
    "task = get_param_in_cwd('task_column') or 'label'\n",
    "p_value = get_param_in_cwd('p_value') or 0.05\n",
    "# Change this to your own clinical data file.\n",
    "test_data = pd.read_csv(get_param_in_cwd('clinic_file'))\n",
    "stats_columns_settings = get_param_in_cwd('stats_columns')\n",
    "continuous_columns_settings = get_param_in_cwd('continuous_columns')\n",
    "mapping_columns_settings = get_param_in_cwd('mapping_columns')\n",
    "# Drop the task column from the clinical data; the label is taken from label_file below.\n",
    "test_data = test_data[[c for c in test_data.columns if c != task]]\n",
    "# Append '.nii.gz' to IDs that do not already end with a NIfTI extension before joining on ID.\n",
    "test_data['ID'] = test_data['ID'].map(lambda x: x if f\"{x}\".endswith(('.nii.gz', '.nii')) else f\"{x}.nii.gz\")\n",
    "group_info = pd.read_csv(get_param_in_cwd('label_file'))\n",
    "print_join_info(test_data, group_info)\n",
    "test_data = pd.merge(test_data, group_info, on='ID', how='inner')\n",
    "test_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feature name cleanup\n",
    "\n",
    "Strip special characters from all feature names."
   ]
  },
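  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def map_cnames(x):\n",
    "    # Keep only the part before the first '（', '|' or '(' (for example a unit suffix).\n",
    "    x = re.split('[（|(]', x)[0]\n",
    "    # Replace '-', ' ' and '/' with underscores and drop '>'.\n",
    "    x = x.replace('-', '_').replace(' ', '_').replace('>', '').replace('/', '_')\n",
    "    return x.strip()\n",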
    "\n",
    "test_data.columns = list(map(map_cnames, test_data.columns))\n",
    "test_data.columns"
   ]
  },
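  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of the cleanup rules, the sketch below applies a copy of `map_cnames` to a few hypothetical column names (the names are made up for this example). Note that a space before a parenthesis becomes a trailing underscore, because spaces are replaced before `strip()` runs.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "def map_cnames(x):\n",
    "    # Copied from the cell above so the example is self-contained.\n",
    "    x = re.split('[（|(]', x)[0]\n",
    "    x = x.replace('-', '_').replace(' ', '_').replace('>', '').replace('/', '_')\n",
    "    return x.strip()\n",
    "\n",
    "# Hypothetical raw column names.\n",
    "examples = ['Age (years)', 'T-stage', 'CA19-9（U/ml）', 'N/A ratio']\n",
    "cleaned = [map_cnames(name) for name in examples]\n",
    "print(cleaned)  # ['Age_', 'T_stage', 'CA19_9', 'N_A_ratio']\n",
    "```\n",
    "\n",
    "If `stats_columns` or `continuous_columns` are configured by hand, they need to use the cleaned names."
   ]
  },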
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data to analyze\n",
    "\n",
    "Get the names of the feature columns to analyze; if none are specified, they are detected automatically."
   ]
  },
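  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default: all columns except the first (ID) and the last two (assumed to be the group and label columns).\n",
    "stats_columns = stats_columns_settings or list(test_data.columns[1:-2])\n",
    "# Use `task` rather than a hard-coded 'label' so a custom task_column keeps working.\n",
    "test_data = test_data.copy()[['ID'] + stats_columns + ['group', task]]\n",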
    "test_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feature column mapping\n",
    "\n",
    "The columns whose values need to be mapped; if none are specified, they are inferred automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mapping_columns = mapping_columns_settings or [c for c in test_data.columns[1:-2] if test_data[c].dtype == object]\n",
    "mapping_columns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data mapping\n",
    "\n",
    "All non-numeric data can be mapped to categorical codes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.utils import map2numerical\n",
    "\n",
    "data, mapping = map2numerical(test_data, mapping_columns=mapping_columns)\n",
    "mapping"
   ]
  },
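  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For readers without access to `map2numerical`, the idea of mapping string categories to integer codes can be sketched with plain pandas. This is an illustrative stand-in built on `pd.factorize`, not the library's implementation, and the toy columns are hypothetical.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy data with one string-valued column.\n",
    "df = pd.DataFrame({'sex': ['M', 'F', 'F', 'M'], 'age': [61, 58, 70, 45]})\n",
    "\n",
    "mapping = {}\n",
    "for col in [c for c in df.columns if df[c].dtype == object]:\n",
    "    # sort=True assigns codes in sorted order of the distinct values.\n",
    "    codes, uniques = pd.factorize(df[col], sort=True)\n",
    "    df[col] = codes\n",
    "    mapping[col] = {value: code for code, value in enumerate(uniques)}\n",
    "\n",
    "print(mapping)  # {'sex': {'F': 0, 'M': 1}}\n",
    "```\n",
    "\n",
    "Keeping the `mapping` dict around makes it possible to translate the codes back to the original categories when reporting results."
   ]
  },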
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import fillna\n",
    "\n",
    "data = fillna(data, fill_mod='50%')\n",
    "continuous_columns = []\n",
    "for col in test_data.columns:\n",
    "    if test_data[col].apply(lambda x: x.is_integer() if isinstance(x, float) else False).all():\n",
    "        test_data[col] = test_data[col].astype(int)\n",
    "\n",
    "for c in stats_columns:\n",
    "#     print(c, np.unique(test_data[c]), test_data[c].dtype)\n",
    "    if len(np.unique(test_data[c])) > 5 or not np.int8 <= test_data[c].dtype <= np.int64:\n",
    "        continuous_columns.append(c)\n",
    "        \n",
    "continuous_columns = continuous_columns_settings or continuous_columns\n",
    "continuous_columns = [c for c in continuous_columns if c not in ('differentation')]\n",
    "continuous_columns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Missing-value imputation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.makedirs('data', exist_ok=True)\n",
    "data.to_csv('data/clinical.csv', index=False)\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Statistical analysis\n",
    "\n",
    "Two output formats are supported, selected by the `pretty` parameter: `True` returns a formatted table, `False` returns dict data.\n",
    "\n",
    "```python\n",
    "def clinic_stats(data: DataFrame, stats_columns: Union[str, List[str]], label_column='label',\n",
    "                 group_column: str = None, continuous_columns: Union[str, List[str]] = None,\n",
    "                 pretty: bool = True) -> Union[dict, DataFrame]:\n",
    "    \"\"\"\n",
    "\n",
    "    Args:\n",
    "        data: the data\n",
    "        stats_columns: names of the columns to summarize\n",
    "        label_column: binary label column, default `label`\n",
    "        group_column: column used to split the statistics by group, e.g. training, test, and validation groups.\n",
    "        continuous_columns: which columns are continuous variables; mean and variance are reported for these.\n",
    "        pretty: bool, whether to prettify the result.\n",
    "\n",
    "    Returns:\n",
    "        stats DataFrame or json\n",
    "\n",
    "    \"\"\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.stats import clinic_stats\n",
    "\n",
    "pd.set_option('display.max_rows', None)\n",
    "stats = clinic_stats(data, \n",
    "                     stats_columns= stats_columns,\n",
    "                     label_column=task, \n",
    "                     group_column='group', \n",
    "                     continuous_columns= continuous_columns, \n",
    "                     pretty=True, verbose=False)\n",
    "stats.to_csv(f'data/stats_{task}.csv', index=False, encoding='utf_8_sig')\n",
    "stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.stats import clinic_stats\n",
    "\n",
    "pd.set_option('display.max_rows', None)\n",
    "stats = clinic_stats(data, \n",
    "                     stats_columns= stats_columns,\n",
    "                     label_column='group', \n",
    "                     group_column=None, \n",
    "                     continuous_columns= continuous_columns, \n",
    "                     pretty=True, verbose=False)\n",
    "stats.to_csv('data/statics.csv', index=False, encoding='utf_8_sig')\n",
    "stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Univariable and multivariable analysis\n",
    "\n",
    "Univariable analysis followed by stepwise multivariable analysis.\n",
    "```python\n",
    "def uni_multi_variable_analysis(data: pd.DataFrame, features: Union[str, List[str]] = None, label_column: str = 'label',\n",
    "                                need_norm: Union[bool, List[bool]] = False, alpha=0.1,\n",
    "                                p_value4multi: float = 0.05, save_dir: Union[str] = None, prefix: str = '',\n",
    "                                **kwargs):\n",
    "    \"\"\"\n",
    "    Univariable and stepwise multivariable analysis; p_value4multi sets the threshold for the multivariable analysis.\n",
    "    Args:\n",
    "        data: the data\n",
    "        features: features to analyze; by default every column except ID and label_column.\n",
    "        label_column: target column\n",
    "        need_norm: whether to standardize all analyzed data, default False\n",
    "        alpha: CI alpha, alpha/2 %; the default 0.1 is documented as a 95% CI (conventionally alpha=0.1 corresponds to a 90% CI)\n",
    "        p_value4multi: threshold for the multivariable analysis, default 0.05\n",
    "        save_dir: where to save the results\n",
    "        prefix: file name prefix\n",
    "        **kwargs:\n",
    "\n",
    "    Returns:\n",
    "\n",
    "    \"\"\"\n",
    "```"
   ]
  },
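  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The general shape of a univariable-then-multivariable workflow can be sketched with `statsmodels`. This is a simplified illustration under stated assumptions: it uses a plain p-value threshold rather than a stepwise procedure, the data are synthetic, and the names `x1`, `x2`, `fit_logit` and `threshold` are hypothetical. It is not how `uni_multi_variable_analysis` is implemented.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import statsmodels.api as sm\n",
    "\n",
    "# Synthetic data (hypothetical): x1 drives the outcome, x2 is pure noise.\n",
    "rng = np.random.default_rng(0)\n",
    "n = 300\n",
    "X = pd.DataFrame({'x1': rng.normal(size=n), 'x2': rng.normal(size=n)})\n",
    "prob = 1 / (1 + np.exp(-1.5 * X['x1']))\n",
    "y = pd.Series((rng.random(n) < prob).astype(int), name='label')\n",
    "\n",
    "def fit_logit(features):\n",
    "    # Logistic regression with an intercept on the given feature subset.\n",
    "    return sm.Logit(y, sm.add_constant(X[features])).fit(disp=0)\n",
    "\n",
    "# Univariable step: one model per feature.\n",
    "uni_p = {f: fit_logit([f]).pvalues[f] for f in X.columns}\n",
    "\n",
    "# Multivariable step: refit with the features that pass the threshold.\n",
    "threshold = 0.05\n",
    "selected = [f for f, p in uni_p.items() if p < threshold]\n",
    "multi = fit_logit(selected)\n",
    "\n",
    "# Odds ratios and their 95% confidence intervals (exponentiated coefficients).\n",
    "odds_ratios = np.exp(multi.params.drop('const'))\n",
    "ci_95 = np.exp(multi.conf_int(alpha=0.05).drop('const'))\n",
    "```\n",
    "\n",
    "In `statsmodels`, `conf_int(alpha=0.05)` yields a 95% interval, so the interval width follows directly from the `alpha` that is passed."
   ]
  },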
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import uni_multi_variable_analysis\n",
    "\n",
    "# Fit on the training group only; the result files are written to save_dir ('img').\n",
    "uni_multi_variable_analysis(data[data['group'] == 'train'], stats_columns, label_column=task,\n",
    "                            save_dir='img', p_value4multi=p_value, algo='logit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the features that remain significant in the multivariable analysis.\n",
    "multi_v = pd.read_csv('img/multivariable_reg.csv')\n",
    "multi_v = multi_v[multi_v['p_value'] <= 0.05]\n",
    "sel_data = data[['ID'] + list(multi_v['feature_name']) + ['group', task]]\n",
    "sel_data.to_csv('data/clinic_sel.csv', index=False)\n",
    "sel_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Keep the OR columns (excluding Log OR), the feature name, and the p-value from each result file.\n",
    "uni = pd.read_csv('img/univariable_reg.csv')\n",
    "uni = uni[[c for c in uni if ('OR' in c and 'Log' not in c) or c in ['feature_name', 'p_value']]]\n",
    "multi = pd.read_csv('img/multivariable_reg.csv')\n",
    "multi = multi[[c for c in multi if ('OR' in c and 'Log' not in c) or c in ['feature_name', 'p_value']]]\n",
    "# Left-join so every univariable row is kept; features absent from the multivariable model are left blank.\n",
    "info = pd.merge(uni, multi, on='feature_name', how='left', suffixes=('_UNI', '_MULTI')).fillna('')\n",
    "os.makedirs('results', exist_ok=True)\n",
    "info.to_csv('results/unimulti.csv', index=False)\n",
    "info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
